{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "064f83ff-e8ff-41c5-b992-3bb8d7e4a16f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from nltk.corpus import stopwords\n",
    "from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessingStopwords\n",
    "from contextualized_topic_models.models.ctm import ZeroShotTM\n",
    "from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "57cc68ab-5958-48e0-b7a8-dfd5bc18ed3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "warnings.filterwarnings(\"ignore\", category = DeprecationWarning)\n",
    "import os\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1d4d77b8-96ee-4754-9618-31f497049cfc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "           category                                               text\n",
      "0              tech  tv future in the hands of viewers with home th...\n",
      "1          business  worldcom boss  left books alone  former worldc...\n",
      "2             sport  tigers wary of farrell  gamble  leicester say ...\n",
      "3             sport  yeading face newcastle in fa cup premiership s...\n",
      "4     entertainment  ocean s twelve raids box office ocean s twelve...\n",
      "...             ...                                                ...\n",
      "2220       business  cars pull down us retail figures us retail sal...\n",
      "2221       politics  kilroy unveils immigration policy ex-chatshow ...\n",
      "2222  entertainment  rem announce new glasgow concert us band rem h...\n",
      "2223       politics  how political squabbles snowball it s become c...\n",
      "2224          sport  souness delight at euro progress boss graeme s...\n",
      "\n",
      "[2225 rows x 2 columns]\n"
     ]
    }
   ],
   "source": [
    "stop_words = stopwords.words('english')\n",
    "stop_words.append(\"said\")\n",
    "bbc_df = pd.read_csv(\"../data/bbc-text.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "81c2c116-9712-4515-9266-35bed4da05af",
   "metadata": {},
   "outputs": [],
   "source": [
    "documents = bbc_df[\"text\"]\n",
    "preprocessor = WhiteSpacePreprocessingStopwords(documents, stopwords_list=stop_words) \n",
    "preprocessed_documents, unpreprocessed_documents, vocab, indices = preprocessor.preprocess() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6c97fa76-61ac-4352-87ec-a5ae2e0a55fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "28b23628d5f84f5bbdb89646c5b96ccd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/12 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tp = TopicModelDataPreparation(\"distiluse-base-multilingual-cased\")\n",
    "training_dataset = tp.fit(text_for_contextual=unpreprocessed_documents, text_for_bow=preprocessed_documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f3e27d5c-4f9a-478a-a237-495fc67bc44d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: [100/100]\t Seen Samples: [217600/222500]\tTrain Loss: 1070.975986256319\tTime: 0:00:00.437354: : 100it [00:45,  2.19it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 35/35 [00:00<00:00, 87.82it/s]\n"
     ]
    }
   ],
   "source": [
    "ctm = ZeroShotTM(bow_size=len(tp.vocab), contextual_size=512, n_components=5, num_epochs=100)\n",
    "ctm.fit(training_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4d6f23e8-7706-4c32-aa5a-9672d8d1c1e1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(list,\n",
       "            {0: ['people',\n",
       "              'music',\n",
       "              'technology',\n",
       "              'users',\n",
       "              'mobile',\n",
       "              'digital',\n",
       "              'tv',\n",
       "              'games',\n",
       "              'net',\n",
       "              'broadband'],\n",
       "             1: ['match',\n",
       "              'injury',\n",
       "              'side',\n",
       "              'back',\n",
       "              'england',\n",
       "              'win',\n",
       "              'victory',\n",
       "              'goal',\n",
       "              'club',\n",
       "              'great'],\n",
       "             2: ['analysts',\n",
       "              'bought',\n",
       "              'oil',\n",
       "              'figures',\n",
       "              'warned',\n",
       "              'value',\n",
       "              'securities',\n",
       "              'october',\n",
       "              'payments',\n",
       "              'analyst'],\n",
       "             3: ['film',\n",
       "              'best',\n",
       "              'actress',\n",
       "              'award',\n",
       "              'director',\n",
       "              'awards',\n",
       "              'stars',\n",
       "              'band',\n",
       "              'oscars',\n",
       "              'album'],\n",
       "             4: ['mr',\n",
       "              'government',\n",
       "              'labour',\n",
       "              'election',\n",
       "              'would',\n",
       "              'party',\n",
       "              'blair',\n",
       "              'minister',\n",
       "              'brown',\n",
       "              'people']})"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ctm.get_topics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "676814c8-d52d-48bf-a6f9-5047599919ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n",
      "/home/zhenya/.cache/pypoetry/virtualenvs/book-2nd-ed-tI3JII_H-py3.10/lib/python3.10/site-packages/ipywidgets/widgets/widget.py:438: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
      "  self.comm = Comm(**args)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e85ac648b5d147f2a8499045f4424ec5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "spanish_news_piece = \"\"\"IBM anuncia el comienzo de la “era de la utilidad cuántica” y anticipa un superordenador en 2033. \n",
    "La compañía asegura haber alcanzado un sistema de computación que no se puede simular con procedimientos clásicos.\"\"\"\n",
    "testing_dataset = tp.transform([spanish_news_piece])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4bccf6f2-b2d5-40a6-8bb1-f42fc3051cab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.97it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[0.5902461 , 0.09361929, 0.14041995, 0.07586181, 0.0998529 ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ctm.get_doc_topic_distribution(testing_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e9601a2-efcd-46db-83fd-847617aa0e0a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
